import copy

import torch
from torch import nn
from einops import rearrange

from Utils.Tool import index_points
from Utils.Registry import Registry
# from Backbones.Point_Learner import MovingGuidance
from Backbones.State_Space_Model import StateSpaceModel
from Utils.Serialization.default import hilbert_encode, z_order_encode

# Module-level registries; the classes below add themselves to them through
# the ``register_module`` decorators.
curves, ssm, statespace = (
    Registry('SpatialFillingCurves'),
    Registry('StateSpaceModel'),
    Registry('ContextAwareStateSpace'),
)


class _GridCurve(object):
    """Shared serialization logic for the grid-based space-filling curves below.

    Coordinates are quantized onto an integer grid with cell size ``grid_size``
    and handed to ``self._encode``, which each concrete curve overrides. For
    batched input the batch index is packed into the bits above the curve code.
    """

    def __init__(self, grid_size):
        self.grid_size = grid_size

    def _encode(self, grid_coord, depth):
        """Return the curve code of every row of ``grid_coord`` (overridden by subclasses)."""
        raise NotImplementedError

    def __call__(self, coordinates):
        """Return one integer serialization code per point.

        Args:
            coordinates: ``(N, D)`` or ``(B, N, D)`` tensor of point positions.
                NOTE(review): the ``depth * 3`` bit budget assumes D == 3 — confirm.

        Returns:
            For 2-D input, the encoder output for the N points. For 3-D input, a
            ``(B, N)`` integer tensor whose high bits hold the batch index, so that
            sorting groups points by batch entry (assuming the encoder output fits
            in ``3 * depth`` bits).

        Raises:
            AssertionError: if the grid needs more than 16 levels, or the packed
                batch index + curve code would not fit in 63 bits.
        """
        if coordinates.dim() < 3:
            batch, num_batches = None, 0
        else:
            num_batches = coordinates.shape[0]
            batch = self.offset2batch(coordinates)
            coordinates = rearrange(coordinates, 'b n d -> (b n) d')

        # Quantize relative to the minimum over all points, so every batch
        # entry shares one grid.
        grid_coord = torch.div(
            coordinates - coordinates.min(0)[0], self.grid_size, rounding_mode="trunc"
        ).int()
        depth = int(grid_coord.max()).bit_length()

        # Bits needed to store the largest batch index (B - 1); unbatched
        # input packs no batch index at all.
        batch_bits = (num_batches - 1).bit_length() if batch is not None else 0
        assert depth <= 16 and depth * 3 + batch_bits <= 63, (
            f"serialization code overflow: depth={depth}, batch_bits={batch_bits}"
        )
        code = self._encode(grid_coord, depth)

        if batch is None:
            return code
        # Place the batch index above the 3 * depth bits used by the curve code.
        code = batch.long() << depth * 3 | code
        return rearrange(code, '(b n) -> b n', b=num_batches)

    @torch.inference_mode()
    def offset2batch(self, coordinates):
        """Return the batch index of each point of a ``(B, N, ...)`` tensor, flattened to B * N entries."""
        num_batches, num_points = coordinates.shape[:2]
        return torch.arange(
            num_batches, device=coordinates.device, dtype=torch.long
        ).repeat_interleave(num_points)


@curves.register_module('Hilbert')
class Hilbert(_GridCurve):
    """Serialize points in Hilbert-curve order over a quantized grid."""

    def _encode(self, grid_coord, depth):
        return hilbert_encode(grid_coord, depth=depth)


@curves.register_module('Z_Order')
class Z_Order(_GridCurve):
    """Serialize points in Z-order (Morton) over a quantized grid."""

    def _encode(self, grid_coord, depth):
        return z_order_encode(grid_coord, depth=depth)


@curves.register_module('withoutCurve')
class withoutCurve(object):
    """Identity "curve": each point's code is its position index, so sorting keeps input order.

    Output shapes match the grid curves: ``(N,)`` for ``(N, D)`` input and
    ``(B, N)`` for ``(B, N, D)`` input.
    """

    def __init__(self, grid_size=None):
        # Accepted (and unused) so this class can be built from the same config
        # as the grid-based curves, which require ``grid_size``.
        self.grid_size = grid_size

    def __call__(self, coordinates):
        if coordinates.dim() == 2:
            return torch.arange(coordinates.shape[0], device=coordinates.device)
        batch, N, _ = coordinates.shape
        # repeat(batch, 1) gives (B, N); the former repeat(batch, 1, 1) gave (B, 1, N).
        return torch.arange(N, device=coordinates.device).repeat(batch, 1)


@ssm.register_module('Bidirectional')
class BidirectionalSSM(nn.Module):
    """LayerNorm followed by two state space models: one scans the sequence in
    order, the other scans it reversed; their outputs are summed.

    NOTE(review): the sequence axis is assumed to be dim -2 (flip axis below),
    i.e. features shaped (..., L, channels) — confirm against StateSpaceModel.
    """

    def __init__(self, channels, d_state, expand):
        super(BidirectionalSSM, self).__init__()
        self.norm = nn.LayerNorm(channels)
        # Two independently parameterized SSMs, constructed forward first.
        self.forwardSSM = StateSpaceModel(channels, d_state, expand)
        self.backwardSSM = StateSpaceModel(channels, d_state, expand)

    def forward(self, features):
        seq_axis = (-2,)
        normed = self.norm(features)
        forward_out = self.forwardSSM(normed)
        # Scan the reversed sequence, then flip back so positions line up.
        backward_out = self.backwardSSM(normed.flip(dims=seq_axis))
        return forward_out + backward_out.flip(dims=seq_axis)


@ssm.register_module('Unidirectional')
class UnidirectionalSSM(nn.Module):
    """LayerNorm followed by a single forward-scanning state space model."""

    def __init__(self, channels, d_state, expand):
        super(UnidirectionalSSM, self).__init__()
        self.norm = nn.LayerNorm(channels)
        self.forwardSSM = StateSpaceModel(channels, d_state, expand)

    def forward(self, features):
        normalized = self.norm(features)
        scanned = self.forwardSSM(normalized)
        return scanned


@statespace.register_module()
class ContextAwareStateSpace(nn.Module):
    """Run an SSM over features reordered by a serialization code, then restore order.

    Args:
        channels: feature width; also written into the SSM config as ``channels``.
        ssm_cfgs: config object for ``ssm.build``. It is copied, never mutated.
        bias: whether the output projection has a bias term.
    """

    def __init__(self, channels, ssm_cfgs, bias=True):
        super(ContextAwareStateSpace, self).__init__()
        self.channels = channels

        # Work on a private copy: the caller's config may be shared by several
        # layers of different widths and must not be overwritten here.
        ssm_cfgs = copy.deepcopy(ssm_cfgs)
        ssm_cfgs.channels = channels
        self.ssm = ssm.build(ssm_cfgs)
        self.out_proj = nn.Linear(channels, channels, bias=bias)

    def forward(self, features, code):
        """Apply the SSM along ascending ``code`` order and project the result.

        Args:
            features: presumably ``(B, N, channels)`` — TODO confirm against ``index_points``.
            code: per-point integer codes, sorted along the last dimension.

        Returns:
            ``out_proj`` applied to the SSM output, in the original point order.
        """
        sorting = torch.argsort(code, dim=-1)
        # argsort of a permutation is its inverse: maps sorted positions back.
        retrieve = torch.argsort(sorting, dim=-1)
        features = index_points(self.ssm(index_points(features, sorting)), retrieve)
        return self.out_proj(features)
